#!/usr/bin/env python3
"""
canonicalize_kernels.py
=======================

This script converts the gauge‑group kernel files provided in
``inputs/kernels`` into a canonical format more amenable to analysis and
simulation.  In the raw format the FPHS kernels are stored as a 3‑D
array of shape ``(2*L*L, N, N)`` where ``N`` is the dimension of the
gauge group (``N=2`` for SU2, ``N=3`` for SU3).  Each entry of this
array corresponds to an individual lattice link and contains a
``N×N`` matrix.  The first ``L*L`` matrices store the link in the
``x``‑direction at each lattice site, and the next ``L*L`` matrices
store the link in the ``y``‑direction.

While this matrix‑valued representation is convenient for gauge
calculations, the kernel‑to‑metric translator only depends on the
amplitude of the links.  A common reduction is to take the Frobenius
norm (square root of the sum of the squared magnitudes of all matrix
entries) of each matrix, producing a real scalar for each link.  The
resulting vector of length ``2*L*L`` is then reshaped into a 3‑D
array of shape ``(2, L, L)`` where the first axis enumerates ``x``
and ``y`` links.

Running this script will locate all kernel files under a specified
``--in`` directory (default: ``inputs/kernels``), canonicalise them
using the Frobenius norm as described above, and write the resulting
arrays into a parallel directory specified by ``--out`` (default:
``inputs/kernels_canonical``) preserving the gauge and lattice size
subdirectories.  Existing files in the output directory will be
overwritten.

Example usage::

    python scripts/canonicalize_kernels.py --in inputs/kernels --out inputs/kernels_canonical

After running, the canonical kernels can be loaded with
``np.load(..., mmap_mode="r")`` and will have dtype ``float64`` and
shape ``(2, L, L)``.

"""

from __future__ import annotations

import argparse
import math
import os
from typing import Tuple

import numpy as np


def canonicalise_kernel(kernel: np.ndarray, L: int) -> np.ndarray:
    """Reduce a raw link kernel to one Frobenius norm per lattice link.

    Parameters
    ----------
    kernel : np.ndarray
        Raw kernel array of shape ``(2*L*L, N, N)``.  The first
        ``L*L`` matrices are the x‑links and the remaining ``L*L``
        matrices are the y‑links.  Real or complex entries are accepted.
    L : int
        Lattice size; must satisfy ``kernel.shape[0] == 2*L*L``.

    Returns
    -------
    np.ndarray
        ``float64`` array of shape ``(2, L, L)``; index 0 of the first
        axis holds the x‑link norms and index 1 the y‑link norms.

    Raises
    ------
    ValueError
        If ``kernel`` is not 3‑D or its first dimension is not
        ``2*L*L``.
    """
    if kernel.ndim != 3:
        raise ValueError(f"expected 3‑D array, got shape {kernel.shape}")
    if kernel.shape[0] != 2 * L * L:
        raise ValueError(
            f"kernel first dimension {kernel.shape[0]} != 2*{L}*{L}"
        )
    # Promote low-precision inputs (float16/float32/complex64, ...) to at
    # least double precision *before* taking the norm.  Otherwise the
    # squared entries are accumulated in the input precision and can
    # overflow or lose accuracy before the final cast to float64.
    work_dtype = np.result_type(kernel.dtype, np.float64)
    norms = np.linalg.norm(
        kernel.astype(work_dtype, copy=False), ord="fro", axis=(1, 2)
    )
    # The x-links occupy the first L*L entries and the y-links the rest,
    # so a single reshape yields the (2, L, L) layout directly.
    return norms.reshape(2, L, L).astype(np.float64)


def process_file(path_in: str, path_out: str) -> None:
    """Load a kernel file, canonicalise it and write it to ``path_out``.

    Two input layouts are accepted:

    * a 1‑D vector of length ``2*L*L``, which is only reshaped to
      ``(2, L, L)`` and cast to ``float64``;
    * a 3‑D array whose first dimension is ``2*L*L``, which is reduced
      with :func:`canonicalise_kernel`.

    NOTE: every 3‑D array is treated as a raw kernel.  An array that is
    already in the canonical ``(2, L, L)`` layout would therefore be
    interpreted as a raw ``L = 1`` kernel and reduced again.

    Parameters
    ----------
    path_in : str
        Path of the ``.npy`` file to read (loaded with pickling disabled).
    path_out : str
        Path of the file to write.  Missing parent directories are
        created; an existing file is overwritten.

    Raises
    ------
    ValueError
        If the array is neither 1‑D nor 3‑D, or if its length / first
        dimension is not of the form ``2*L*L`` for an integer ``L``.
    """
    arr = np.load(path_in, allow_pickle=False)
    if arr.ndim == 1:
        # Flat vector of per-link scalars: only a reshape is required.
        n_links = arr.size
        # Exact integer square root; avoids float rounding in ``** 0.5``.
        side = math.isqrt(n_links // 2)
        if 2 * side * side != n_links:
            raise ValueError(
                f"cannot infer L from vector of length {n_links} (expected 2*L^2)"
            )
        out = arr.reshape(2, side, side).astype(np.float64)
    elif arr.ndim == 3:
        # Matrix-valued kernel: the first axis enumerates the 2*L*L links.
        n_links = arr.shape[0]
        side = math.isqrt(n_links // 2)
        if 2 * side * side != n_links:
            raise ValueError(
                f"first dim {n_links} != 2*L^2 for some integer L"
            )
        out = canonicalise_kernel(arr, side)
    else:
        raise ValueError(
            f"unsupported array shape {arr.shape} in {path_in}"
        )
    # ``os.path.dirname`` is "" for a bare file name and ``os.makedirs("")``
    # raises, so only create the parent directory when there is one.
    out_dir = os.path.dirname(path_out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(path_out, out)


def main() -> None:
    """Command-line entry point: canonicalise every kernel under ``--in``.

    Each first-level subdirectory of the input root is treated as a
    gauge directory and searched recursively for ``.npy`` files, so any
    nested (e.g. lattice-size) subdirectories are handled as well.  Every
    file is converted with :func:`process_file` and written to the same
    relative path under the output root.  Files directly inside the input
    root are ignored.

    A file that fails to convert is reported with a ``[WARN]`` line and
    skipped; the remaining files are still processed.  If the input
    directory does not exist the parser reports the error and exits.
    """
    parser = argparse.ArgumentParser(description="Canonicalise FPHS kernels using Frobenius norms.")
    parser.add_argument(
        "--in",
        dest="in_dir",
        default="inputs/kernels",
        help="Input directory containing gauge subdirectories with kernel files",
    )
    parser.add_argument(
        "--out",
        dest="out_dir",
        default="inputs/kernels_canonical",
        help="Output directory to write canonical kernels",
    )
    args = parser.parse_args()

    in_root = args.in_dir
    out_root = args.out_dir

    # Fail with a clear usage error instead of a raw traceback.
    if not os.path.isdir(in_root):
        parser.error(f"input directory not found: {in_root}")

    # If the output root lives inside the input tree it must not be
    # walked, otherwise previously written canonical files would be
    # picked up and reduced a second time.
    out_abs = os.path.abspath(out_root)

    # Sorted traversal keeps the processing order (and the log)
    # deterministic across platforms and runs.
    for gauge in sorted(os.listdir(in_root)):
        g_dir = os.path.join(in_root, gauge)
        if not os.path.isdir(g_dir) or os.path.abspath(g_dir) == out_abs:
            continue
        for dirpath, dirnames, filenames in os.walk(g_dir):
            # Prune in place so os.walk neither descends into the output
            # root nor visits subdirectories in arbitrary order.
            dirnames[:] = sorted(
                d
                for d in dirnames
                if os.path.abspath(os.path.join(dirpath, d)) != out_abs
            )
            # Mirror the input layout (gauge/.../file) under the output root.
            rel_dir = os.path.relpath(dirpath, in_root)
            for fname in sorted(filenames):
                if not fname.endswith(".npy"):
                    continue
                in_path = os.path.join(dirpath, fname)
                out_path = os.path.join(out_root, rel_dir, fname)
                # Best-effort batch conversion: one bad file must not
                # abort the run, so report it and carry on.
                try:
                    process_file(in_path, out_path)
                    print(f"Converted {in_path} → {out_path}")
                except Exception as exc:
                    print(f"[WARN] Skipped {in_path}: {exc}")


# Run the command-line interface only when this file is executed as a
# script; importing the module has no side effects.
if __name__ == "__main__":
    main()